//LoadStoreUnit.scala
package mycore

import chisel3._
import chisel3.util._

/** Combinational load/store unit.
  *
  * Decodes the load/store flags carried in `io.isa`, forms the effective address as
  * `src1 + imm.I` for loads and `src1 + imm.S` for stores, and drives the `DataLoad` and
  * `DataStore` ports. There is no register in this module: every output is a pure function
  * of the current inputs.
  *
  * Ports:
  *  - `isa`       decoded instruction flags (LD/LW/LH/LB/LWU/LHU/LBU and SD/SW/SH/SB are read here)
  *  - `src1`      base address operand
  *  - `src2`      store data operand
  *  - `imm`       immediates; the `I` field is used for loads and the `S` field for stores
  *  - `Valid`     qualifies both enable outputs
  *  - `DataLoad`  load port: drives `en`, `addr`, `size`; reads `data`
  *  - `DataStore` store port: drives `en`, `addr`, `data`, `mask`
  *  - `load_data` extended load result; zero when no load flag is asserted
  *
  * NOTE(review): lane selection and mask generation use only the low three address bits and
  * appear to assume naturally aligned accesses; a mask shifted past byte lane 7 is simply
  * truncated. Confirm that misaligned accesses are rejected or split elsewhere.
  */
class LoadStoreUnit extends Module {
  val io = IO(new Bundle {
    val isa = Input(new ISA)
    val src1 = Input(UInt(64.W))
    val src2 = Input(UInt(64.W))
    val imm = Input(new IMM)

    val Valid = Input(Bool())

    val DataLoad = new DataLoad
    val DataStore = new DataStore

    val load_data = Output(UInt(64.W))
  })

  /** Passes `data` through when `sel` is high, otherwise yields zero.
    * OR-ing several gated values together forms an AND-OR selector.
    */
  private def gate(sel: Bool, data: UInt): UInt = Mux(sel, data, 0.U)

  /** OR-reduces a list of (select, value) pairs after gating each value by its select. */
  private def andOr(cases: Seq[(Bool, UInt)]): UInt =
    cases.map { case (sel, data) => gate(sel, data) }.reduce(_ | _)

  val loads = Seq(
    io.isa.LD, io.isa.LW, io.isa.LH, io.isa.LB, io.isa.LWU, io.isa.LHU, io.isa.LBU
  ).reduce(_ || _)
  val stores = Seq(io.isa.SD, io.isa.SW, io.isa.SH, io.isa.SB).reduce(_ || _)

  // ---------------------------------------------------------------- load path
  val load_addr = io.src1 + io.imm.I
  io.DataLoad.en := io.Valid && loads
  io.DataLoad.addr := load_addr

  // Size code: 0 = byte, 1 = halfword, 2 = word, 3 = doubleword.
  // Byte loads (and the "no load decoded" case) fall out as 0 because nothing is OR-ed in.
  val load_size = andOr(Seq(
    (io.isa.LH || io.isa.LHU) -> 1.U(3.W),
    (io.isa.LW || io.isa.LWU) -> 2.U(3.W),
    io.isa.LD -> 3.U(3.W)
  ))
  io.DataLoad.size := load_size

  // Narrow the returned doubleword step by step, one address bit per halving.
  // A single-bit extract is already a Bool, so no explicit conversion is required.
  val d_data = io.DataLoad.data
  val w_data = Mux(load_addr(2), d_data(63, 32), d_data(31, 0))
  val h_data = Mux(load_addr(1), w_data(31, 16), w_data(15, 0))
  val b_data = Mux(load_addr(0), h_data(15, 8), h_data(7, 0))

  // SignExt / ZeroExt are declared elsewhere in the project; presumably they widen the
  // operand to the given bit count with sign / zero fill — TODO confirm.
  io.load_data := andOr(Seq(
    io.isa.LD -> d_data,
    io.isa.LW -> SignExt(w_data, 64),
    io.isa.LH -> SignExt(h_data, 64),
    io.isa.LB -> SignExt(b_data, 64),
    io.isa.LWU -> ZeroExt(w_data, 64),
    io.isa.LHU -> ZeroExt(h_data, 64),
    io.isa.LBU -> ZeroExt(b_data, 64)
  ))

  // --------------------------------------------------------------- store path
  val store_addr = io.src1 + io.imm.S
  io.DataStore.en := io.Valid && stores
  io.DataStore.addr := store_addr

  // Replicate the narrow operand across the whole doubleword so that whichever byte
  // lanes the mask enables already carry the right data.
  val store_data = andOr(Seq(
    io.isa.SD -> io.src2,
    io.isa.SW -> Fill(2, io.src2(31, 0)),
    io.isa.SH -> Fill(4, io.src2(15, 0)),
    io.isa.SB -> Fill(8, io.src2(7, 0))
  ))

  // A dynamic left shift widens its result, so cut it back to the 8 byte lanes explicitly
  // rather than relying on a later implicit truncation.
  val byte_offset = store_addr(2, 0)
  private def laneMask(base: UInt): UInt = (base << byte_offset)(7, 0)

  val byte_mask = andOr(Seq(
    io.isa.SD -> "b1111_1111".U(8.W),
    io.isa.SW -> laneMask("b0000_1111".U(8.W)),
    io.isa.SH -> laneMask("b0000_0011".U(8.W)),
    io.isa.SB -> laneMask("b0000_0001".U(8.W))
  ))

  // 8-bit byte-enable mask -> 64-bit bit mask (each enable bit repeated 8 times, bit 0 lowest).
  val store_mask = FillInterleaved(8, byte_mask)

  io.DataStore.data := store_data & store_mask
  io.DataStore.mask := byte_mask

}